import argparse
import csv
import os
import re
import shlex
import subprocess
import sys
import time

# msrb_fragmenter_batch.py
# Run MSRB-Fragmenter over a batch of compounds as a Slurm array job.
# David Arndt, Feb 2020

def run_argument_parser():
    """Build the command-line parser for a batch MSRB-Fragmenter run.

    None of the options is marked required, so each one is None when it is
    omitted from the command line.
    """
    parser = argparse.ArgumentParser(description='Batch MSRB-Fragmenter')
    option_help = (
        ('-i', 'Input .tsv file'),
        ('-o', 'Output file'),
        ('-a', 'Array id'),  # for running with sbatch on Compute Canada
        ('-s', 'Array size, i.e. split input list into this many Slurm jobs'),
    )
    # NOTE: a '-c' option (spectra type/charge: positive, negative, or ei)
    # was once sketched here but is not implemented.
    for flag, text in option_help:
        parser.add_argument(flag, help=text)
    return parser

# Status line MSRB-Fragmenter prints on stdout; compiled once for all rows.
STATUS_PATTERN = re.compile(r'STATUS REPORT = (\d+)')


def count_nonempty_rows(path):
    """Return how many non-empty rows the tab-separated file at *path* has."""
    with open(path) as tsvfile:
        return sum(1 for row in csv.reader(tsvfile, delimiter='\t') if row)


def slice_bounds(array_id, array_size, total):
    """Return the half-open row range ``(start, end)`` owned by one array task.

    Rows are numbered from 0. Task ``array_id`` (1-based) of ``array_size``
    owns rows ceil((array_id - 1) * total / array_size) up to, but not
    including, ceil(array_id * total / array_size). Tasks 1..array_size
    therefore cover rows 0..total-1 exactly once. Integer arithmetic keeps
    the boundaries exact, so ``end - 1`` is a real row index.
    """
    start = ((array_id - 1) * total + array_size - 1) // array_size
    end = (array_id * total + array_size - 1) // array_size
    return start, end


def iter_assigned_rows(path, start, end):
    """Yield ``(idx, row)`` for non-empty TSV rows with ``start <= idx < end``.

    ``idx`` counts non-empty rows only, which are the same rows that
    count_nonempty_rows counts. If blank lines were numbered too, trailing
    rows could fall past the last task's ``end`` and no task would process
    them.
    """
    with open(path) as tsvfile:
        idx = 0
        for row in csv.reader(tsvfile, delimiter='\t'):
            if not row:
                continue
            if idx >= end:
                break  # rows are visited in order; nothing further is ours
            if idx >= start:
                yield idx, row
            idx += 1


def run_fragmenter(smiles, output_path, env):
    """Run MSRB-Fragmenter on one SMILES string.

    Returns ``(returncode, stdout_text, stderr_text)``.
    """
    # Build the argument list directly instead of formatting a string and
    # shlex-splitting it. This way, quote characters in the SMILES cannot
    # break or alter the command line.
    cmd = ['java', '-jar', 'msrb-fragmenter.jar',
           '-ismi', smiles, '-o', output_path]
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, shell=False, env=env)
    stdout, stderr = proc.communicate()
    # 'replace': a stray non-UTF-8 byte in the tool's output must not abort
    # the remaining rows of this array task.
    return (proc.returncode,
            stdout.decode('utf-8', 'replace'),
            stderr.decode('utf-8', 'replace'))


def write_no_status_log(idx, stdout_str, stderr_str):
    """Save the output of a run that exited 0 without a STATUS REPORT line."""
    with open("msrbfrag_no_status_{0}.log".format(idx), "w") as err_log:
        err_log.write('No error code returned from MSRB-Fragmenter\n')
        err_log.write('idx: ' + str(idx) + '\n')
        err_log.write('STDOUT:\n')
        err_log.write(stdout_str + '\n')
        err_log.write('STDERR:\n')
        err_log.write(stderr_str + '\n')


def main():
    """Process this Slurm array task's share of the input TSV.

    Each non-empty input row is expected to hold ``inchikey<TAB>smiles``.
    For every row assigned to this task, MSRB-Fragmenter is asked to write
    ``results/<inchikey>.txt``. A line ``inchikey<TAB>smiles<TAB>status`` is
    written to ``<-o>_<-a>``. The status is the tool's STATUS REPORT code,
    ``-`` if it exited 0 without reporting one, or the negated return code
    otherwise.
    """
    args = run_argument_parser().parse_args()

    # sys.exit(message) prints to stderr and exits with status 1, so Slurm
    # records a misconfigured task as failed.
    required = (
        (args.i, "Missing input file."),
        (args.o, "Missing output file."),
        (args.a, "Missing array id."),
        (args.s, "Missing array size."),
    )
    for value, message in required:
        if value is None:
            sys.exit(message)

    infile = args.i
    outfile = args.o + '_' + args.a
    array_id = int(args.a)
    array_size = int(args.s)
    if array_size < 1 or not 1 <= array_id <= array_size:
        sys.exit("Array id must be between 1 and array size (got id %d, size %d)."
                 % (array_id, array_size))

    start_time = time.time()

    # Determine which rows of the input file this task should process.
    num_input_lines = count_nonempty_rows(infile)
    start_idx, end_idx = slice_bounds(array_id, array_size, num_input_lines)

    env = os.environ.copy()
    # See https://stackoverflow.com/questions/33793620/java-what-determines-the-maximum-max-heap-size-possible-in-a-linux-machine
    env["MALLOC_ARENA_MAX"] = '4'

    processed = 0
    with open(outfile, "w") as out:
        for idx, row in iter_assigned_rows(infile, start_idx, end_idx):
            inchikey = row[0]
            smiles = row[1]
            # The output name needs an extension; without one the program crashes.
            result_path = 'results/' + inchikey + '.txt'
            rc, stdout_str, stderr_str = run_fragmenter(smiles, result_path, env)

            if rc != 0:
                # Store the negated return code. Popen reports death by
                # signal N as -N, so signals end up recorded as positive N.
                status = str(-rc)
            else:
                match = STATUS_PATTERN.search(stdout_str)
                if match:
                    status = match.group(1)
                else:
                    status = '-'
                    write_no_status_log(idx, stdout_str, stderr_str)
            out.write(inchikey + '\t' + smiles + '\t' + status + '\n')

            processed += 1
            if processed % 500 == 0:
                # Flush periodically so finished rows survive if Slurm kills
                # the task at its time limit.
                out.flush()
                elapsed = time.time() - start_time
                print("%d done in %.0f s" % (processed, elapsed))

    elapsed = time.time() - start_time
    print("Processed compounds %d-%d in %.0f s" % (start_idx, end_idx - 1, elapsed))


if __name__ == "__main__":
    main()
